import requests
import sqlite3
import json
import os, time, re, sys

# for retrieving API_KEY
from dotenv import load_dotenv
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from utils.config import load_config
from utils.logger import set_logger

# Load the configuration file and create the logger used throughout this file
config = load_config()
logger = set_logger()

# Resolve the script directory to an absolute path BEFORE changing the working
# directory. A relative `__file__` would otherwise be re-resolved against the
# new working directory (producing a wrong path), and `os.chdir('')` raises
# when `__file__` carries no directory component.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Setting the working directory to the directory of the script
os.chdir(SCRIPT_DIR)

# CONSTANTS (all values come from the loaded configuration)
START_DATETIME = config['start_datetime']  # e.g. '2023-01-01T00:00:00Z'
END_DATETIME = config['end_datetime']  # e.g. '2023-12-31T23:59:59Z'
RADIUS = config['radius']              # e.g. "7070m"
COUNTRY_NAME = config['country_name']  # e.g. "mexico"
COUNTRY_CODE = config['country_code']  # e.g. "mx"
YOUTUBE_API_URL = 'https://www.googleapis.com/youtube/v3'

# Convert the start and end datetimes to the format `YYYYMM`:
# keep the first 7 characters (e.g. '2023-01') and strip every non-digit.
START_YEAR_MONTH = re.sub(r'[^0-9]', '', START_DATETIME[:7])
END_YEAR_MONTH = re.sub(r'[^0-9]', '', END_DATETIME[:7])


# PATHS
# PROJECT_ROOT is the directory containing this script (kept under its
# original name for backward compatibility).
PROJECT_ROOT = SCRIPT_DIR
INPUT_DIR     = os.path.join(PROJECT_ROOT, '../../../data', COUNTRY_NAME)
OUTPUT_DIR    = os.path.join(PROJECT_ROOT, '../../../output', COUNTRY_NAME)
# Both databases live in a `<YYYY>-<MM>-collection` sub-directory of OUTPUT_DIR.
QUERY_DB_PATH = os.path.join(OUTPUT_DIR, f'{START_YEAR_MONTH[:4]}-{START_YEAR_MONTH[4:]}-collection/{START_YEAR_MONTH}_{END_YEAR_MONTH}_01_{COUNTRY_CODE}_query_log.db')
VIDEO_DB_PATH = os.path.join(OUTPUT_DIR, f'{START_YEAR_MONTH[:4]}-{START_YEAR_MONTH[4:]}-collection/{START_YEAR_MONTH}_{END_YEAR_MONTH}_02_{COUNTRY_CODE}_youtube_data.db')

class VideosDatabase:
    """Search YouTube videos by location and store the results in SQLite.

    The class reads pending location queries from the query-log database
    (``query_db_path``), runs a YouTube Data API search for each of them,
    fetches the details of every video found and writes everything to the
    video database (``video_db_path``).
    """

    # Seconds to wait for a single HTTP request before `requests` raises.
    REQUEST_TIMEOUT = 30
    # Pause applied after an HTTP 403 API error: 24 hours and 1 minute.
    QUOTA_PAUSE_SECONDS = 24 * 3600 + 60

    def __init__(self, query_db_path=QUERY_DB_PATH, video_db_path=VIDEO_DB_PATH):
        """Load the API key and open the video database.

        Args:
            query_db_path: Path of the SQLite database holding the `query` table.
            video_db_path: Path of the SQLite database that receives the videos.
        """
        self.query_db_path = query_db_path
        self.video_db_path = video_db_path
        self.api_key = self.get_api_key()
        self.youtube_api_url = YOUTUBE_API_URL

        # Create directory if it does not exist. A bare file name has an empty
        # dirname, which os.makedirs would reject.
        db_dir = os.path.dirname(self.video_db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        # Establish a connection to the database
        logger.info(f"Connecting to the database: {self.video_db_path}...")
        if not os.path.exists(self.video_db_path):
            # sqlite3.connect() below creates the file when it is missing.
            logger.info(f"Database not found: {self.video_db_path}. Creating the database first...")

        self.conn = sqlite3.connect(self.video_db_path)
        self.c = self.conn.cursor()

    def __del__(self):
        """Close the video database connection if one was opened."""
        # __init__ may have failed before `conn` was assigned.
        conn = getattr(self, 'conn', None)
        if conn is not None:
            conn.close()

    def get_api_key(self):
        """Return the value of the `API_KEY` environment variable.

        The `.env` file located one directory above this script is loaded
        first. Returns whatever `os.getenv` yields (``None`` when unset).
        """
        load_dotenv(os.path.join(os.path.dirname(__file__), '..', '.env'))
        api_key = os.getenv('API_KEY')
        if not api_key:
            logger.info("API_KEY is not set in the environment or the .env file.")
        return api_key

    def create_tables(self):
        """Create the `video_searches` and `videos` tables if they are missing."""
        logger.info("Creating the tables...")
        # One row per (video, search location, radius) combination.
        self.c.execute('''
        CREATE TABLE IF NOT EXISTS video_searches (
            video_id TEXT,
            title TEXT,
            description TEXT,
            publishedAt TEXT,
            latitude REAL,
            longitude REAL,
            radius TEXT,
            PRIMARY KEY (video_id, latitude, longitude, radius)
        )
        ''')

        # One row per video with the details returned by the `videos` endpoint.
        self.c.execute('''
        CREATE TABLE IF NOT EXISTS videos (
            video_id TEXT PRIMARY KEY,
            publishedAt DATETIME,
            channelId TEXT,
            title TEXT,
            description TEXT,
            channelTitle TEXT,
            categoryId TEXT,
            liveBroadcastContent TEXT,
            defaultLanguage TEXT,
            localized_title TEXT,
            localized_description TEXT,
            defaultAudioLanguage TEXT,
            duration TEXT,
            dimension TEXT,
            definition TEXT,
            caption TEXT,
            licensedContent BOOLEAN,
            projection TEXT,
            hasCustomThumbnail BOOLEAN,
            viewCount INTEGER,
            likeCount INTEGER,
            dislikeCount INTEGER,
            favoriteCount INTEGER,
            commentCount INTEGER,
            recordingDate DATETIME,
            tags TEXT, -- JSON encoded array
            topicIds TEXT, -- JSON encoded array
            relevantTopicIds TEXT, -- JSON encoded array
            topicCategories TEXT -- JSON encoded array
        );
        ''')
        self.conn.commit()
        logger.info("Tables created successfully.")

    def is_video_in_database(self, video_id):
        """Return True when `video_id` already has a row in the `videos` table."""
        self.c.execute('SELECT 1 FROM videos WHERE video_id = ?', (video_id,))
        return self.c.fetchone() is not None

    def _get_json(self, url, params):
        """Send a GET request and return the decoded JSON body.

        A timeout is always applied so that a stalled connection raises
        instead of blocking the collection indefinitely.
        """
        response = requests.get(url, params=params, timeout=self.REQUEST_TIMEOUT)
        return response.json()

    def get_video_details(self, video_id):
        """Fetch the details of one video and store them in the `videos` table.

        Videos already present in the database are skipped. An API error with
        code 403 pauses for 24 hours and 1 minute and then retries; any other
        API error is logged and the video is skipped.

        Args:
            video_id: YouTube video id to look up.

        Returns:
            None in every case.
        """
        if self.is_video_in_database(video_id):
            return None

        params = {
            'part': 'snippet,contentDetails,statistics,liveStreamingDetails,topicDetails,recordingDetails,localizations',
            'id': video_id,
            'key': self.api_key
        }

        # Loop (rather than recurse) until a response without a 403 error arrives.
        while True:
            response_json = self._get_json(f"{self.youtube_api_url}/videos", params)
            error = response_json.get('error')
            if not error:
                break
            if error.get('code') == 403:
                logger.info("User Rate Limit Exceeded. Pausing for 24 hours and 1 minute.")
                time.sleep(self.QUOTA_PAUSE_SECONDS)
                continue
            # Report the real cause instead of describing it as an empty result.
            logger.info(f"Error fetching details for video {video_id}: {error.get('message')}")
            return None

        if 'items' not in response_json:
            logger.info("No 'items' key in response, likely an empty result")
            return None

        items = response_json['items']
        if not items:
            logger.info("Items list is empty, no data found")
            return None

        video_data = items[0]
        if video_data:
            self.insert_video_data(video_data)
        return None

    def insert_video_data(self, video):
        """Insert (or replace) one row of the `videos` table.

        Args:
            video: A single item of the `videos` endpoint response. It must
                contain `id` and a `snippet` with `publishedAt`, `channelId`,
                `title`, `description` and `channelTitle`; every other field
                falls back to a default when absent.
        """
        snippet = video['snippet']
        content_details = video.get('contentDetails', {})
        statistics = video.get('statistics', {})
        topic_details = video.get('topicDetails', {})
        recording_details = video.get('recordingDetails', {})
        localizations = video.get('localizations', {})

        # The API exposes the localized title/description in `snippet.localized`;
        # `localizations` is keyed by language code. The former 'default' lookup
        # is kept only as a fallback.
        localized = snippet.get('localized') or localizations.get('default', {})

        self.c.execute('''
        INSERT OR REPLACE INTO videos (
            video_id, publishedAt, channelId, title, description, channelTitle, categoryId, liveBroadcastContent, defaultLanguage,
            localized_title, localized_description, defaultAudioLanguage, duration, dimension, definition, caption, licensedContent, projection,
            hasCustomThumbnail, viewCount, likeCount, dislikeCount, favoriteCount, commentCount, recordingDate, tags, topicIds, relevantTopicIds,
            topicCategories)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    ''', (
            video['id'],
            snippet['publishedAt'],
            snippet['channelId'],
            snippet['title'],
            snippet['description'],
            snippet['channelTitle'],
            snippet.get('categoryId', ''),
            snippet.get('liveBroadcastContent', ''),
            snippet.get('defaultLanguage', ''),
            localized.get('title', ''),
            localized.get('description', ''),
            snippet.get('defaultAudioLanguage', ''),
            content_details.get('duration', ''),
            content_details.get('dimension', ''),
            content_details.get('definition', ''),
            content_details.get('caption', ''),
            content_details.get('licensedContent', ''),
            content_details.get('projection', ''),
            content_details.get('hasCustomThumbnail', False),
            statistics.get('viewCount', 0),
            statistics.get('likeCount', 0),
            statistics.get('dislikeCount', 0),
            statistics.get('favoriteCount', 0),
            statistics.get('commentCount', 0),
            recording_details.get('recordingDate', ''),
            # List-valued fields are stored as JSON encoded arrays.
            json.dumps(snippet.get('tags', [])),
            json.dumps(topic_details.get('topicIds', [])),
            json.dumps(topic_details.get('relevantTopicIds', [])),
            json.dumps(topic_details.get('topicCategories', []))
        ))
        self.conn.commit()

    def fetch_all_pages(self, url, params):
        """Fetch every result page of an API request and return all items.

        Pages are followed through `nextPageToken` until none is returned.
        An API error with code 403 pauses for 24 hours and 1 minute and
        retries the same page.

        Args:
            url: Endpoint URL to query.
            params: Query parameters; the dict is not modified.

        Returns:
            A list with the `items` of every page fetched.

        Raises:
            RuntimeError: On any API error other than 403, so that the caller
                does not treat a failed query as completed.
        """
        # Work on a copy so the caller's dict never receives a `pageToken`.
        request_params = dict(params)
        all_items = []
        pagecount = 1
        while True:  # Continue fetching until there's no nextPageToken
            data = self._get_json(url, request_params)

            if 'error' in data:
                error = data['error']
                logger.info(f"Error fetching data: {error.get('message')}")
                if error.get('code') == 403:
                    logger.info("User Rate Limit Exceeded. Pausing for 24 hours and 1 minute.")
                    time.sleep(self.QUOTA_PAUSE_SECONDS)
                    continue  # Retry the same page after sleeping
                raise RuntimeError(
                    f"YouTube API error {error.get('code')}: {error.get('message')}"
                )

            all_items.extend(data.get('items', []))

            next_page_token = data.get('nextPageToken')
            if not next_page_token:
                break  # No more pages to fetch

            request_params['pageToken'] = next_page_token
            total_on_page = data.get('pageInfo', {}).get('totalResults')
            logger.info(f"Total results on pageInfo {pagecount}: {total_on_page}")
            pagecount += 1

        # Summary of the whole query, based on the last page received.
        total_reported = data.get('pageInfo', {}).get('totalResults', 0)
        logger.info(f"Location: {request_params.get('location')}")
        logger.info(f"Published After: {request_params.get('publishedAfter')}")
        logger.info(f"Published Before: {request_params.get('publishedBefore')}")
        logger.info(f"No. of pages fetched: {pagecount}")
        logger.info(f"Total no. of videos reported: {total_reported}")
        logger.info(f"Actual no. of videos fetched: {len(all_items)}")
        logger.info(f"No. of vidoes missed: {total_reported - len(all_items)}\n.")

        return all_items

    def search_videos_by_location(self, latitude, longitude, publishedAfter, publishedBefore, radius=RADIUS, max_results=50):
        """Search videos around a location and record them in `video_searches`.

        Args:
            latitude: Latitude of the search centre.
            longitude: Longitude of the search centre.
            publishedAfter: Lower bound passed as the `publishedAfter` parameter.
            publishedBefore: Upper bound passed as the `publishedBefore` parameter.
            radius: Value of the `locationRadius` parameter.
            max_results: Value of the `maxResults` parameter (page size).

        Returns:
            The list of video ids found, in the order they were returned.
        """
        params = {
            'part': 'snippet',
            'type': 'video',
            'location': f"{latitude},{longitude}",
            'locationRadius': radius,
            'maxResults': max_results,
            'publishedAfter': publishedAfter,
            'publishedBefore': publishedBefore,
            'key': self.api_key
        }
        items = self.fetch_all_pages(f"{self.youtube_api_url}/search", params)

        video_ids = []
        for item in items:
            item_id = item.get('id')
            if not (isinstance(item_id, dict) and 'videoId' in item_id):
                # Format the item into the message; a stray positional argument
                # without a placeholder would not be rendered.
                logger.info(f"Missing video ID in API response item: {item}")
                continue

            video_id = item_id['videoId']
            snippet = item.get('snippet', {})
            self.c.execute('''
                INSERT OR REPLACE INTO video_searches (video_id, title, description, publishedAt, latitude, longitude, radius)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                ''', (
                video_id,
                snippet.get('title', ''),
                snippet.get('description', ''),
                snippet.get('publishedAt', ''),
                latitude,
                longitude,
                radius,
            ))
            video_ids.append(video_id)

        # One commit per search instead of one per row.
        self.conn.commit()
        return video_ids

    def read_locations_and_search_videos(self):
        """Process every pending row of the `query` table.

        For each row with `query_completed = 0` the location is searched, the
        details of every video found are fetched and stored, and the row is
        then marked with `query_completed = 1`.

        Raises:
            Exception: Any error is logged, the query-log transaction is
                rolled back and the exception is re-raised to the caller.
        """
        conn = sqlite3.connect(self.query_db_path)
        c = conn.cursor()
        try:
            c.execute('SELECT id, latitude, longitude, publishedAfter, publishedBefore FROM query WHERE query_completed = 0')
            locations = c.fetchall()
            total = len(locations)
            for count, location in enumerate(locations, start=1):
                loc_id, latitude, longitude, publishedAfter, publishedBefore = location
                video_ids = self.search_videos_by_location(latitude, longitude, publishedAfter, publishedBefore)
                for video_id in video_ids:
                    logger.info(f"Video id: {video_id}")
                    self.get_video_details(video_id)  # Fetch and store video details

                # Mark the query as completed only after all its videos were handled
                c.execute('UPDATE query SET query_completed = 1 WHERE id = ?', (loc_id,))
                conn.commit()
                logger.info(f"Location {loc_id} processed successfully.")
                logger.info(f"\nSEARCH Progress: {count}/{total}.\n")
        except Exception as e:
            conn.rollback()
            logger.info(f"Error during database operation: {e}")
            raise  # Re-raise exception to be handled by the caller
        finally:
            conn.close()


def main():
    """Create the tables and process all pending locations, with retries.

    The processing step is attempted up to `max_retries` times, waiting
    `retry_delay` seconds between attempts. Locations already marked as
    completed are not selected again on a retry.

    Raises:
        Exception: The error of the last attempt, once all attempts failed.
    """
    retry_delay = 15  # seconds
    max_retries = 5  # max number of attempts

    videos_db = VideosDatabase()
    videos_db.create_tables()

    for attempt in range(1, max_retries + 1):
        try:
            videos_db.read_locations_and_search_videos()
        except Exception as e:
            if attempt == max_retries:
                # Final failure: do not announce a retry or sleep before exiting.
                logger.info(f"An error occurred: {e}. (Attempt {attempt}/{max_retries})")
                logger.info("Maximum retry attempts reached. Exiting.")
                raise  # Re-raise the last exception after final attempt
            logger.info(f"An error occurred: {e}. Retrying in {retry_delay} seconds... (Attempt {attempt}/{max_retries})")
            time.sleep(retry_delay)  # Wait before retrying
        else:
            logger.info("Successfully processed all locations.")
            break  # Exit loop if successful


# Run the collection only when this file is executed as a script
if __name__ == '__main__':
    main()
